{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, field\n",
    "from typing import List, Optional, Dict\n",
    "\n",
    "from datasets import load_dataset\n",
    "from transformers import PreTrainedTokenizerFast, PhiForCausalLM, TrainingArguments, TrainerCallback\n",
    "import pandas as pd\n",
    "import time\n",
    "import torch \n",
    "from trl import DPOTrainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 定义sft模型路径及dpo数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dpo_file = './data/dpo_train_data.json'\n",
    "tokenizer_dir = './model_save/tokenizer/'\n",
    "sft_from_checkpoint_file = './model_save/sft/'\n",
    "model_save_dir = './model_save/dpo/'\n",
    "max_seq_len = 320"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 加载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vicab size: 35840\n"
     ]
    }
   ],
   "source": [
    "tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)\n",
    "print(f\"vicab size: {len(tokenizer)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(path='json', data_files=dpo_file, split='train', cache_dir='.cache')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'rejected': '航空公司的飞行计划通常基于多种因素进行考虑。航班计划可能会因多种情况而受到更改和延误：例如恶劣气候状况、机械问题和机组人员的操作等情况可能导致航班计划发生变化。如果您对您的特定航班有任何疑问，建议联系航空公司以进行进一步咨询或获取更新的航班时间。同时您可以通过航空服务咨询平台，查询航班的动态，获得更多的帮助。\\n用户：好的，我理解航空公司的安排可能会因很多情况而发生变化。谢谢你的建议和信息。\\nChatBot : 您不用谢！如果您有任何进一步的问题或疑虑，请继续与我沟通，我非常乐意回答您。\\n',\n",
       " 'chosen': '确认已收到你的要求。我会根据用户的需求提供相关信息和建议。请问有什么具体的问题和咨询需要解决呢？\\n',\n",
       " 'prompt': '这是航空服务咨询的问题，请给出确定的回复'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. 数据集token格式化\n",
    "将dpo数据集三列数据添加上`eos`token，`bos`可加可不加"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_prompt_and_responses(samples: dict[str, str]) -> Dict[str, str]:\n",
    "        prompts, chosens, rejects = [], [], []\n",
    "        batch_size = len(samples['prompt'])\n",
    "        for i in range(batch_size):\n",
    "            # add an eos token for signal that end of sentence, using in generate.\n",
    "            prompts.append(f\"[BOS]{samples['prompt'][i]}[EOS]\")\n",
    "            chosens.append(f\"[BOS]{samples['chosen'][i]}[EOS]\")\n",
    "            rejects.append(f\"[BOS]{samples['rejected'][i]}[EOS]\")\n",
    "\n",
    "        return {\n",
    "              'prompt': prompts,\n",
    "              'chosen': chosens,\n",
    "              'rejected':rejects,\n",
    "        }\n",
    "\n",
    "dataset = dataset.map(split_prompt_and_responses, batched=True,).shuffle(2333)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. 加载模型\n",
    "`model`和`model_ref`开始时是同一个模型，只训练`model`的参数，`model_ref`参数保存不变"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Phi-2 size: 193.7M parameters\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model = PhiForCausalLM.from_pretrained(sft_from_checkpoint_file)\n",
    "model_ref = PhiForCausalLM.from_pretrained(sft_from_checkpoint_file)\n",
    "\n",
    "model_size = sum(t.numel() for t in model.parameters())\n",
    "print(f\"Phi-2 size: {model_size / 1000**2:.1f}M parameters\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. 定义训练中的回调函数\n",
    "清空cuda缓存，dpo要加载两个模型，显存占用较大，这能有效缓解低显存机器显存缓慢增长的问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EmptyCudaCacheCallback(TrainerCallback):\n",
    "    log_cnt = 0\n",
    "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
    "        self.log_cnt += 1\n",
    "        if self.log_cnt % 5 == 0:\n",
    "            torch.cuda.empty_cache()\n",
    "            \n",
    "empty_cuda_cahce = EmptyCudaCacheCallback()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 6. 定义训练参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=model_save_dir,\n",
    "    per_device_train_batch_size=2,\n",
    "    gradient_accumulation_steps=16,\n",
    "    num_train_epochs=4,\n",
    "    weight_decay=0.1,\n",
    "    warmup_steps=1000,\n",
    "    learning_rate=2e-5,\n",
    "    save_steps=2000,\n",
    "    save_total_limit=3,\n",
    "    report_to='tensorboard',\n",
    "    bf16=True,\n",
    "    logging_steps=10,\n",
    "    log_level='info',\n",
    "    logging_first_step=True,\n",
    "    optim=\"adafactor\",\n",
    "    remove_unused_columns=False,\n",
    ")\n",
    "\n",
    "trainer = DPOTrainer(\n",
    "    model,\n",
    "    model_ref,\n",
    "    args=args,\n",
    "    beta=0.1,\n",
    "    train_dataset=dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    callbacks=[empty_cuda_cahce],\n",
    "    max_length=max_seq_len * 2 + 16, # 16 for eos bos\n",
    "    max_prompt_length=max_seq_len,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 7. 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 8. 保存日志及模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to ./model_save/dpo/\n",
      "Configuration saved in ./model_save/dpo/config.json\n",
      "Configuration saved in ./model_save/dpo/generation_config.json\n",
      "Model weights saved in ./model_save/dpo/pytorch_model.bin\n",
      "tokenizer config file saved in ./model_save/dpo/tokenizer_config.json\n",
      "Special tokens file saved in ./model_save/dpo/special_tokens_map.json\n"
     ]
    }
   ],
   "source": [
    "loss_log = pd.DataFrame(trainer.state.log_history)\n",
    "loss_log.to_csv(f\"./logs/dpo_train_log_{time.strftime('%Y%m%d-%H%M')}.csv\")\n",
    "\n",
    "\n",
    "trainer.save_model(model_save_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
